{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
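    "# Fall Recovery Training Notebook\n",
    "\n",
    "Trains a PPO policy (Brax) on the `Go1Getup` environment from `mujoco_playground`, saves checkpoints and the final parameters, then rolls out the trained policy, plots power and joint-velocity diagnostics, and renders a video."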
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reload imported modules automatically before each cell executes, so edits\n",
    "# to local source files take effect without restarting the kernel.\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "xla_flags = os.environ.get(\"XLA_FLAGS\", \"\")\n",
    "xla_flags += \" --xla_gpu_triton_gemm_any=True\"\n",
    "os.environ[\"XLA_FLAGS\"] = xla_flags\n",
    "os.environ[\"XLA_PYTHON_CLIENT_PREALLOCATE\"] = \"false\"\n",
    "os.environ[\"MUJOCO_GL\"] = \"egl\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import functools\n",
    "import json\n",
    "import pickle\n",
    "from datetime import datetime\n",
    "from pprint import pprint\n",
    "\n",
    "import jax\n",
    "import jax.numpy as jp\n",
    "import matplotlib.pyplot as plt\n",
    "import mediapy as media\n",
    "import mujoco\n",
    "import wandb\n",
    "from brax.training.agents.ppo import networks as ppo_networks\n",
    "from brax.training.agents.ppo import train as ppo\n",
    "from etils import epath\n",
    "from flax.training import orbax_utils\n",
    "from IPython.display import clear_output, display\n",
    "from orbax import checkpoint as ocp\n",
    "from tqdm import tqdm\n",
    "\n",
    "from mujoco_playground import locomotion, wrapper\n",
    "from mujoco_playground.config import locomotion_params\n",
    "\n",
    "\n",
    "# Persistent compilation cache settings.\n",
    "JAX_CACHE_DIR = \"/tmp/jax_cache\"\n",
    "jax.config.update(\"jax_compilation_cache_dir\", JAX_CACHE_DIR)\n",
    "# NOTE(review): 0 / -1 look like \"no minimum\" thresholds, i.e. cache every\n",
    "# compiled program regardless of compile time or size -- confirm in JAX docs.\n",
    "jax.config.update(\"jax_persistent_cache_min_compile_time_secs\", 0)\n",
    "jax.config.update(\"jax_persistent_cache_min_entry_size_bytes\", -1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pick the task, then look up its PPO hyper-parameters, default environment\n",
    "# config and domain randomizer from mujoco_playground.\n",
    "env_name = \"Go1Getup\"\n",
    "ppo_params = locomotion_params.brax_ppo_config(env_name)\n",
    "env_cfg = locomotion.get_default_config(env_name)\n",
    "randomizer = locomotion.get_domain_randomizer(env_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
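    "# Optional: override entries of the default environment config here, before\n",
    "# training starts. Example (disabled):\n",
    "# env_cfg.reward_config.scales.dof_vel = 0.0"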
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show the PPO hyper-parameters that training will use.\n",
    "pprint(ppo_params)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Weights & Biases logging is opt-in; set USE_WANDB to True to enable it.\n",
    "USE_WANDB = False\n",
    "\n",
    "if USE_WANDB:\n",
    "  # Start a run seeded with the environment config, then record the task name.\n",
    "  wandb.init(project=\"mjxrl\", config=env_cfg)\n",
    "  wandb.config.update({\"env_name\": env_name})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "SUFFIX = None  # Optional tag appended to the experiment name.\n",
    "FINETUNE_PATH = None  # Optional directory of checkpoints to resume from.\n",
    "\n",
    "# Generate unique experiment name.\n",
    "now = datetime.now()\n",
    "timestamp = now.strftime(\"%Y%m%d-%H%M%S\")\n",
    "exp_name = f\"{env_name}-{timestamp}\"\n",
    "if SUFFIX is not None:\n",
    "  exp_name += f\"-{SUFFIX}\"\n",
    "print(f\"Experiment name: {exp_name}\")\n",
    "\n",
    "# Possibly restore from the latest checkpoint.\n",
    "restore_checkpoint_path = None\n",
    "if FINETUNE_PATH is not None:\n",
    "  FINETUNE_PATH = epath.Path(FINETUNE_PATH)\n",
    "  # Checkpoint directories are named after their integer training step. Skip\n",
    "  # everything else so a stray folder cannot crash the int() conversion.\n",
    "  latest_ckpts = [\n",
    "      ckpt\n",
    "      for ckpt in FINETUNE_PATH.glob(\"*\")\n",
    "      if ckpt.is_dir() and ckpt.name.isdigit()\n",
    "  ]\n",
    "  if not latest_ckpts:\n",
    "    # Fail with an actionable message instead of a bare IndexError.\n",
    "    raise FileNotFoundError(\n",
    "        f\"No step-numbered checkpoint directories found in {FINETUNE_PATH}\"\n",
    "    )\n",
    "  # The highest step number is the most recent checkpoint.\n",
    "  latest_ckpt = max(latest_ckpts, key=lambda ckpt: int(ckpt.name))\n",
    "  restore_checkpoint_path = latest_ckpt\n",
    "  print(f\"Restoring from: {restore_checkpoint_path}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Absolute directory that holds every artifact of this experiment.\n",
    "ckpt_path = epath.Path(\"checkpoints\").resolve() / exp_name\n",
    "ckpt_path.mkdir(parents=True, exist_ok=True)\n",
    "print(f\"{ckpt_path}\")\n",
    "\n",
    "# Store the environment config next to the checkpoints.\n",
    "config_file = ckpt_path / \"config.json\"\n",
    "with open(config_file, \"w\") as config_out:\n",
    "  json.dump(env_cfg.to_dict(), config_out, indent=4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# History shared by progress() and the final plot: one entry per evaluation.\n",
    "x_data, y_data, y_dataerr = [], [], []\n",
    "# times[0] is the start of this cell; progress() appends one timestamp per call.\n",
    "times = [datetime.now()]\n",
    "\n",
    "\n",
    "def progress(num_steps, metrics):\n",
    "  \"\"\"Records one evaluation result and redraws the live reward curve.\n",
    "\n",
    "  Args:\n",
    "    num_steps: Number of environment steps so far; used as the x value and as\n",
    "      the wandb step.\n",
    "    metrics: Mapping of metric names to values. Must contain\n",
    "      \"eval/episode_reward\" and \"eval/episode_reward_std\".\n",
    "  \"\"\"\n",
    "  # Log to wandb.\n",
    "  if USE_WANDB:\n",
    "    wandb.log(metrics, step=num_steps)\n",
    "\n",
    "  times.append(datetime.now())\n",
    "  x_data.append(num_steps)\n",
    "  y_data.append(metrics[\"eval/episode_reward\"])\n",
    "  y_dataerr.append(metrics[\"eval/episode_reward_std\"])\n",
    "\n",
    "  # Plot. The whole history is redrawn on every call, so clear the figure\n",
    "  # first; otherwise each call stacks another copy of the curve on the axes.\n",
    "  clear_output(wait=True)\n",
    "  plt.clf()\n",
    "  plt.xlim([0, ppo_params.num_timesteps * 1.25])\n",
    "  plt.ylim([0, 75])\n",
    "  plt.xlabel(\"# environment steps\")\n",
    "  plt.ylabel(\"reward per episode\")\n",
    "  plt.title(f\"y={y_data[-1]:.3f}\")\n",
    "  plt.errorbar(x_data, y_data, yerr=y_dataerr, color=\"blue\")\n",
    "\n",
    "  display(plt.gcf())\n",
    "\n",
    "\n",
    "def policy_params_fn(current_step, make_policy, params):\n",
    "  \"\"\"Saves `params` as an Orbax checkpoint under `ckpt_path/<current_step>`.\n",
    "\n",
    "  `make_policy` is part of the callback signature but is not used here.\n",
    "  \"\"\"\n",
    "  del make_policy  # Unused.\n",
    "  step_dir = ckpt_path / f\"{current_step}\"\n",
    "  checkpointer = ocp.PyTreeCheckpointer()\n",
    "  checkpointer.save(\n",
    "      step_dir,\n",
    "      params,\n",
    "      force=True,\n",
    "      save_args=orbax_utils.save_args_from_target(params),\n",
    "  )\n",
    "\n",
    "# Split the PPO config into keyword arguments for ppo.train and the optional\n",
    "# keyword arguments for the network factory.\n",
    "training_params = dict(ppo_params)\n",
    "network_factory = ppo_networks.make_ppo_networks\n",
    "if \"network_factory\" in training_params:\n",
    "  # Configs without this entry fall back to the default PPO networks instead\n",
    "  # of failing on the missing key.\n",
    "  network_factory = functools.partial(\n",
    "      ppo_networks.make_ppo_networks,\n",
    "      **training_params.pop(\"network_factory\"),\n",
    "  )\n",
    "\n",
    "train_fn = functools.partial(\n",
    "    ppo.train,\n",
    "    **training_params,\n",
    "    network_factory=network_factory,\n",
    "    restore_checkpoint_path=restore_checkpoint_path,\n",
    "    progress_fn=progress,\n",
    "    wrap_env_fn=wrapper.wrap_for_brax_training,\n",
    "    policy_params_fn=policy_params_fn,\n",
    "    randomization_fn=randomizer,\n",
    ")\n",
    "\n",
    "# Separate environment instances for training and evaluation, same config.\n",
    "env = locomotion.load(env_name, config=env_cfg)\n",
    "eval_env = locomotion.load(env_name, config=env_cfg)\n",
    "make_inference_fn, params, _ = train_fn(environment=env, eval_env=eval_env)\n",
    "# times[1] is the first progress() call, so the first interval is reported as\n",
    "# JIT time and the remainder as training time.\n",
    "if len(times) > 1:\n",
    "  print(f\"time to jit: {times[1] - times[0]}\")\n",
    "  print(f\"time to train: {times[-1] - times[1]}\")\n",
    "\n",
    "# Make a final plot of reward vs WALLCLOCK time.\n",
    "# progress() appends to `times` when it records each evaluation, so the i-th\n",
    "# reward was measured at times[i + 1]; times[0] is the start of the cell.\n",
    "elapsed_secs = [(t - times[0]).total_seconds() for t in times[1:]]\n",
    "plt.figure()\n",
    "plt.ylim([0, 75])\n",
    "plt.xlabel(\"wallclock time (s)\")\n",
    "plt.ylabel(\"reward per episode\")\n",
    "plt.title(f\"y={y_data[-1]:.3f}\")\n",
    "plt.errorbar(\n",
    "    elapsed_secs,\n",
    "    y_data,\n",
    "    yerr=y_dataerr,\n",
    "    color=\"blue\",\n",
    ")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save normalizer, policy and value params to the checkpoint dir.\n",
    "normalizer_params, policy_params, value_params = params\n",
    "data = {\n",
    "  \"normalizer_params\": normalizer_params,\n",
    "  \"policy_params\": policy_params,\n",
    "  \"value_params\": value_params,\n",
    "}\n",
    "with open(ckpt_path / \"params.pkl\", \"wb\") as f:\n",
    "  pickle.dump(data, f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rebuild the evaluation environment with drop_from_height_prob set to 1.0.\n",
    "# NOTE(review): this mutates the shared env_cfg in place, so re-running an\n",
    "# earlier cell afterwards will see the changed value.\n",
    "env_cfg.drop_from_height_prob = 1.0\n",
    "eval_env = locomotion.load(env_name, config=env_cfg)\n",
    "jit_step = jax.jit(eval_env.step)\n",
    "jit_reset = jax.jit(eval_env.reset)\n",
    "\n",
    "# Deterministic policy built from the trained parameters.\n",
    "inference_fn = make_inference_fn(params, deterministic=True)\n",
    "jit_inference_fn = jax.jit(inference_fn)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "rng = jax.random.PRNGKey(12345)\n",
    "\n",
    "# Per-step logs, accumulated across all evaluation episodes.\n",
    "rollout, rewards, actions = [], [], []\n",
    "torso_height, torques, qvels = [], [], []\n",
    "qfrc_constraint = []\n",
    "power, power1, power2 = [], [], []\n",
    "\n",
    "num_episodes = 10\n",
    "steps_per_episode = env_cfg.episode_length // 2\n",
    "reward_prefix = \"reward/\"\n",
    "\n",
    "for _ in tqdm(range(num_episodes)):\n",
    "  rng, reset_rng = jax.random.split(rng)\n",
    "  state = jit_reset(reset_rng)\n",
    "  for step_idx in range(steps_per_episode):\n",
    "    act_rng, rng = jax.random.split(rng)\n",
    "    ctrl, _ = jit_inference_fn(state.obs, act_rng)\n",
    "    state = jit_step(state, ctrl)\n",
    "\n",
    "    # Quantities derived from the post-step state.\n",
    "    frc = state.data.actuator_force\n",
    "    # Skips the first 6 velocity entries -- presumably the floating base;\n",
    "    # TODO confirm against the model.\n",
    "    qvel = state.data.qvel[6:]\n",
    "    signed_power = jp.sum(frc * qvel)\n",
    "    abs_power = jp.sum(jp.abs(frc * qvel))\n",
    "\n",
    "    actions.append(ctrl)\n",
    "    rollout.append(state)\n",
    "    # Keep only the \"reward/...\" metrics, keyed without the prefix.\n",
    "    rewards.append({\n",
    "        name.removeprefix(reward_prefix): value\n",
    "        for name, value in state.metrics.items()\n",
    "        if name.startswith(reward_prefix)\n",
    "    })\n",
    "    torso_height.append(state.data.qpos[2])\n",
    "    torques.append(frc)\n",
    "    power.append(abs_power)\n",
    "    qfrc_constraint.append(jp.linalg.norm(state.data.qfrc_constraint[6:]))\n",
    "    qvels.append(jp.max(jp.abs(qvel)))\n",
    "    power1.append(signed_power)\n",
    "    power2.append(abs_power)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Smooth the signed power trace (power1) with a 10-sample moving average.\n",
    "smoothing_window = jp.ones(10) / 10\n",
    "power = jp.convolve(jp.array(power1), smoothing_window, mode='valid')\n",
    "\n",
    "plt.plot(power)\n",
    "# Mean of the smoothed trace as a horizontal line.\n",
    "plt.axhline(y=jp.mean(power), color='r', linestyle='-', label='Mean')\n",
    "plt.legend()\n",
    "plt.show()\n",
    "\n",
    "# Print min and max.\n",
    "print(f\"Min power: {jp.min(power)}, Max power: {jp.max(power)}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-step maximum of |qvel[6:]| against a 2*pi reference line.\n",
    "plt.plot(qvels)\n",
    "plt.axhline(y=2 * jp.pi, color=\"red\", label=\"limit\")\n",
    "plt.legend()\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Render every `render_every`-th state and scale the frame rate to match.\n",
    "render_every = 2\n",
    "traj = rollout[::render_every]\n",
    "fps = 1.0 / eval_env.dt / render_every\n",
    "\n",
    "scene_option = mujoco.MjvOption()\n",
    "# Geom group visibility.\n",
    "for group, visible in ((2, True), (3, False)):\n",
    "  scene_option.geomgroup[group] = visible\n",
    "# Visualization flags: contact points on, contact forces and transparency off.\n",
    "vis_flags = mujoco.mjtVisFlag\n",
    "for flag, enabled in (\n",
    "    (vis_flags.mjVIS_CONTACTPOINT, True),\n",
    "    (vis_flags.mjVIS_CONTACTFORCE, False),\n",
    "    (vis_flags.mjVIS_TRANSPARENT, False),\n",
    "):\n",
    "  scene_option.flags[flag] = enabled\n",
    "\n",
    "frames = eval_env.render(\n",
    "    traj,\n",
    "    camera=\"side\",\n",
    "    scene_option=scene_option,\n",
    "    height=480,\n",
    "    width=640,\n",
    ")\n",
    "media.show_video(frames, fps=fps, loop=False)\n",
    "# media.write_video(f\"{env_name}.mp4\", frames, fps=fps, qp=18)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.15"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
